#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
#


from typing import Tuple, Union

from torch.nn import Module

from lightly_train._models.dinov2_vit.dinov2_vit_src.models import (
    vision_transformer as vits,
)


def build_model(
    args, only_teacher=False, img_size=224
) -> Union[Tuple[Module, int], Tuple[Module, Module, int]]:
    """Build the teacher (and, unless ``only_teacher``, the student) ViT.

    Args:
        args: Namespace-like model config. Reads ``arch``, ``patch_size``,
            ``layerscale``, ``ffn_layer``, ``block_chunks``, ``qkv_bias``,
            ``proj_bias``, ``ffn_bias``, ``num_register_tokens``,
            ``interpolate_offset`` and ``interpolate_antialias``; the student
            additionally reads ``drop_path_rate`` and ``drop_path_uniform``.
            A trailing ``"_memeff"`` suffix on ``args.arch`` is stripped
            in place (the caller's object is modified).
        only_teacher: If True, build and return only the teacher model.
        img_size: Image size passed to the architecture constructor.

    Returns:
        ``(teacher, embed_dim)`` when ``only_teacher`` is True, otherwise
        ``(student, teacher, embed_dim)``. ``embed_dim`` is the
        ``embed_dim`` attribute of the returned model (teacher in the first
        case, student in the second).

    Raises:
        ValueError: If ``args.arch`` (after suffix stripping) does not
            contain ``"vit"``.
        KeyError: If ``args.arch`` is not a name in the
            ``vision_transformer`` module namespace.
    """
    suffix = "_memeff"
    if args.arch.endswith(suffix):
        # Side effect: the stripped name is written back to the caller's args.
        args.arch = args.arch[: -len(suffix)]
    if "vit" not in args.arch:
        # Without this guard the function would reach its final `return`
        # with unbound locals and fail with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported architecture: {args.arch!r}")

    vit_kwargs = dict(
        img_size=img_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
    )
    # Architecture constructors are looked up by name in the vits module.
    model_fn = vits.__dict__[args.arch]
    teacher = model_fn(**vit_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim
    # The student shares the teacher's kwargs plus the drop-path settings.
    student = model_fn(
        **vit_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    embed_dim = student.embed_dim
    return student, teacher, embed_dim


def build_model_from_cfg(
    cfg, only_teacher=False
) -> Union[Tuple[Module, int], Tuple[Module, Module, int]]:
    """Build the teacher/student models from a full training config.

    Forwards ``cfg.student`` as the model args and
    ``cfg.crops.global_crops_size`` as the image size to ``build_model``.

    Args:
        cfg: Config object exposing ``student`` and ``crops.global_crops_size``.
        only_teacher: If True, build only the teacher model.

    Returns:
        ``(teacher, embed_dim)`` when ``only_teacher`` is True, otherwise
        ``(student, teacher, embed_dim)``.
    """
    return build_model(
        cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size
    )
